package com.wstuo.common.config.backup.service;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

import javax.sql.DataSource;

import org.apache.log4j.Logger;
import org.dbunit.DatabaseUnitException;
import org.dbunit.database.DatabaseConfig;
import org.dbunit.database.DatabaseConnection;
import org.dbunit.database.DefaultMetadataHandler;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.CachedDataSet;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.stream.IDataSetProducer;
import org.dbunit.dataset.xml.XmlProducer;
import org.dbunit.ext.mysql.MySqlDataTypeFactory;
import org.dbunit.operation.DatabaseOperation;
import org.h2.tools.SimpleResultSet;
import org.h2.tools.SimpleRowSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Transactional;
import org.xml.sax.InputSource;

public class DBXMLHelper {

	private static final Logger LOGGER = Logger.getLogger(DBXMLHelper.class);
	@Autowired
	private DataSource dataSource;

	/**
	 * 获取连接
	 * 
	 * @return
	 * @throws DatabaseUnitException
	 */
	private IDatabaseConnection getConnection() {

		Connection con = null;
		IDatabaseConnection connection = null;
		try {
			con = dataSource.getConnection();
			/*
			 * Mysql 连接
			 */
			connection = new DatabaseConnection(con);
			connection.getConnection()
					.prepareStatement("set @@session.foreign_key_checks = 0")
					.execute();
			connection.getConfig().setProperty(
					DatabaseConfig.PROPERTY_DATATYPE_FACTORY,
					new MySqlDataTypeFactory());
			connection.getConfig().setProperty(
					DatabaseConfig.PROPERTY_METADATA_HANDLER,
					new MyDefaultMetadataHandler());

			// oracle
			/*
			 * connection= new OracleConnection(con,"PHIL1");
			 * connection.getConnection
			 * ().prepareStatement("set @@session.foreign_key_checks = 0"
			 * ).execute(); connection.getConfig().setProperty(DatabaseConfig.
			 * PROPERTY_DATATYPE_FACTORY, new Oracle10DataTypeFactory());
			 * connection
			 * .getConfig().setProperty(DatabaseConfig.PROPERTY_METADATA_HANDLER
			 * , new MyDefaultMetadataHandler());
			 * 
			 * 
			 * mssql connection = new DatabaseConnection(con);
			 * connection.getConnection
			 * ().prepareStatement("set @@session.foreign_key_checks = 0"
			 * ).execute(); connection.getConfig().setProperty(DatabaseConfig.
			 * PROPERTY_DATATYPE_FACTORY, new MsSqlDataTypeFactory());
			 * connection
			 * .getConfig().setProperty(DatabaseConfig.PROPERTY_METADATA_HANDLER
			 * , new MyDefaultMetadataHandler());
			 */
		} catch (SQLException e) {
			LOGGER.error("SQLException", e);
		} catch (DatabaseUnitException e) {
			LOGGER.error("DatabaseUnitException", e);
		}
		return connection;

	}

	/**
	 * 导出数据库.
	 * 
	 * @param fileName
	 * @throws SQLException
	 * @throws DatabaseUnitException
	 * @throws FileNotFoundException
	 * @throws IOException
	 */
	@Transactional
	public void exportDatabase(String fileName) throws SQLException,
			DatabaseUnitException, FileNotFoundException, IOException {

		IDatabaseConnection conn = null;
		InputStream inputStream = null;
		try {
			/**
			 * 方法二 命令备份 以.sql格式结尾
			 * */
			Properties properties = null;

			inputStream = DBXMLHelper.class.getClassLoader()
					.getResourceAsStream("hibernate.properties");

			properties = new Properties();
			properties.load(inputStream);

			Runtime rt = Runtime.getRuntime();
			String uname = properties.getProperty("dataSource.username");
			String upwd = properties.getProperty("dataSource.password");
			String udatabase = properties
					.getProperty("dataSource.databaseName");
			String stmt1 = "mysqldump " + udatabase + " -u " + uname + " -p"
					+ upwd + " --result-file=" + fileName;

			rt.exec(stmt1); // 设置导出编码为utf8。这里必须是utf8
		} finally {
			if(inputStream!=null){
				inputStream.close();
			}
			if (conn != null) {
				conn.close();
			}
		}

	}

	/**
	 * 直接还原数据库（不清除数据）.
	 * 
	 * @param fileName
	 * @throws DatabaseUnitException
	 * @throws SQLException
	 */
	@Transactional
	public void refreshData(String fileName) {
		// String input=dto.getFileName();
		IDatabaseConnection conn = null;
		IDataSetProducer producer = null;
		IDataSet dataset = null;
		try {
			conn = getConnection();

			/*
			 * Mysql 还原
			 */
			producer = new XmlProducer(new InputSource(fileName));
			dataset = new CachedDataSet(producer);
			DatabaseOperation.REFRESH.execute(conn, dataset);

			// oracle
			/*
			 * producer = new XmlProducer(new InputSource(input)); dataset = new
			 * CachedDataSet(producer);
			 * DatabaseOperation.CLEAN_INSERT.execute(conn, dataset);
			 */
			// MSSQL
			/*
			 * producer = new XmlProducer(new InputSource(input)); dataset = new
			 * CachedDataSet(producer);
			 * org.dbunit.ext.mssql.InsertIdentityOperation
			 * .REFRESH.execute(conn, dataset);
			 */

		} catch (Exception ex) {
			LOGGER.error("refreshData Exception", ex);
		} finally {
			if (conn != null) {
				try {
					conn.close();
				} catch (SQLException e) {
					LOGGER.error("refreshData SQLException", e);
				}
			}
		}
	}

	static class MyDefaultMetadataHandler extends DefaultMetadataHandler {
		public ResultSet getPrimaryKeys(DatabaseMetaData metaData,
				String schemaName, String tableName) throws SQLException {
			ResultSet resultSet = super.getPrimaryKeys(metaData, schemaName,
					tableName);
			if (resultSet.next()) {
				resultSet.close();
				resultSet = super.getPrimaryKeys(metaData, schemaName,
						tableName);
			} else {
				resultSet.close();
				ResultSet pkRS = super.getColumns(metaData, schemaName,
						tableName);
				List<Object[]> list = new ArrayList<Object[]>();
				SimpleResultSet simpleResultSet = new SimpleResultSet(
						new MySimpleRowSource(list));
				int i = 1;
				boolean isInit = false;
				try {
					while (pkRS.next()) {
						if (!isInit) {
							ResultSetMetaData md = pkRS.getMetaData();
							simpleResultSet.addColumn("TABLE_CAT",
									md.getColumnType(1), md.getPrecision(1),
									md.getScale(1));
							simpleResultSet.addColumn("TABLE_SCHEM",
									md.getColumnType(2), md.getPrecision(2),
									md.getScale(2));
							simpleResultSet.addColumn("TABLE_NAME",
									md.getColumnType(3), md.getPrecision(3),
									md.getScale(3));
							simpleResultSet.addColumn("COLUMN_NAME",
									md.getColumnType(4), md.getPrecision(4),
									md.getScale(4));
							simpleResultSet.addColumn("KEY_SEQ",
									md.getColumnType(5), md.getPrecision(5),
									md.getScale(5));
							simpleResultSet.addColumn("COLUMN_NAME",
									md.getColumnType(4), md.getPrecision(4),
									md.getScale(4));
							isInit = true;
						}
						Object[] objs = new Object[] { pkRS.getString(1),
								pkRS.getString(2), pkRS.getString(3),
								pkRS.getString(4), i++, pkRS.getString(4) };

						list.add(objs);
					}
				} finally {
					if(pkRS!=null){
						pkRS.close();
					}
					
				}
				resultSet = simpleResultSet;
			}
			return resultSet;
		}
	}

	static class MySimpleRowSource implements SimpleRowSource {

		List<Object[]> datas = new ArrayList<Object[]>();
		int current;

		public MySimpleRowSource(List<Object[]> datas) {
			this.datas = datas;
			current = 0;
		}

		public Object[] readRow() throws SQLException {
			return datas.size() > current ? datas.get(current++) : null;
		}

		public void close() {
			datas.clear();
		}

		public void reset() throws SQLException {
			current = 0;
		}

	}

}
